STROKE PREDICTION

Muhammad Faiz
3/3/2021

LinkedIn : https://www.linkedin.com/in/faiz-arif/
Instagram : @iztagram

According to the World Health Organization (WHO), stroke is the 2nd leading cause of death globally, responsible for approximately 11% of total deaths. A machine learning model is therefore needed to predict earlier which people are likely to have a stroke. But before processing the data, you should know what a stroke is.

What is a Stroke?

In [43]:
from PIL import Image
Image.open('illustration_of_a_stroke.jpg').resize((800,300))
Out[43]:
[illustration of a stroke]

According to stroke.org, stroke is a disease that affects the arteries leading to and within the brain. It is the No. 5 cause of death and a leading cause of disability in the United States.

A stroke occurs when a blood vessel that carries oxygen and nutrients to the brain is either blocked by a clot or bursts (or ruptures). When that happens, part of the brain cannot get the blood (and oxygen) it needs, so it and brain cells die.

What are the effects of stroke?

The brain is an extremely complex organ that controls various body functions. If a stroke occurs and blood flow can't reach the region that controls a particular body function, that part of the body won't work as it should.

<div style="page-break-after: always;"></div>

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from plotly.subplots import make_subplots
import plotly.graph_objs as go
import plotly.express as px
import seaborn as sns

Data Introduction

The data is sourced from a Kaggle dataset whose variables describe patient information.

In [2]:
df = pd.read_csv("dataset/healthcare-dataset-stroke-data.csv")

df.info()
df.head(3)
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5110 entries, 0 to 5109
Data columns (total 12 columns):
 #   Column             Non-Null Count  Dtype  
---  ------             --------------  -----  
 0   id                 5110 non-null   int64  
 1   gender             5110 non-null   object 
 2   age                5110 non-null   float64
 3   hypertension       5110 non-null   int64  
 4   heart_disease      5110 non-null   int64  
 5   ever_married       5110 non-null   object 
 6   work_type          5110 non-null   object 
 7   Residence_type     5110 non-null   object 
 8   avg_glucose_level  5110 non-null   float64
 9   bmi                4909 non-null   float64
 10  smoking_status     5110 non-null   object 
 11  stroke             5110 non-null   int64  
dtypes: float64(3), int64(4), object(5)
memory usage: 479.2+ KB
Out[2]:
id gender age hypertension heart_disease ever_married work_type Residence_type avg_glucose_level bmi smoking_status stroke
0 9046 Male 67.0 0 1 Yes Private Urban 228.69 36.6 formerly smoked 1
1 51676 Female 61.0 0 0 Yes Self-employed Rural 202.21 NaN never smoked 1
2 31112 Male 80.0 0 1 Yes Private Rural 105.92 32.5 never smoked 1

The data comes from Kaggle, which is a collection of patient data with the following variables:
1) id: unique identifier
2) gender: "Male", "Female" or "Other"
3) age: age of the patient
4) hypertension: 0 if the patient doesn't have hypertension, 1 if the patient has hypertension
5) heart_disease: 0 if the patient doesn't have any heart diseases, 1 if the patient has a heart disease
6) ever_married: "No" or "Yes"
7) work_type: "children", "Govt_job", "Never_worked", "Private" or "Self-employed"
8) Residence_type: "Rural" or "Urban"
9) avg_glucose_level: average glucose level in blood
10) bmi: body mass index
11) smoking_status: "formerly smoked", "never smoked", "smokes" or "Unknown"
12) stroke: 1 if the patient had a stroke or 0 if not
Note: "Unknown" in smoking_status means that the information is unavailable for this patient
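As a quick sanity check on the codebook above, the distinct values of each categorical column can be listed directly. A minimal sketch on a few illustrative rows (on the loaded dataset this would simply be `df[col].unique()`):

```python
import pandas as pd

# Illustrative rows standing in for the Kaggle file (same column names).
df = pd.DataFrame({
    "work_type": ["Private", "Self-employed", "Govt_job", "children"],
    "smoking_status": ["smokes", "never smoked", "Unknown", "formerly smoked"],
})

# List the distinct values of each categorical column.
for col in ["work_type", "smoking_status"]:
    print(col, "->", sorted(df[col].unique()))
```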

In [3]:
df['age'] = df['age'].astype('int')

<div style="page-break-after: always;"></div>

Exploratory Data Analysis

In [4]:
df_stroke = df.loc[df['stroke'] == 1]
In [5]:
plt.figure(figsize=(10,5))
plt.style.use("bmh")
sns.distplot(df_stroke['age'], color='gray', kde=True, kde_kws={"color": "k"})
plt.axvline(df_stroke['age'].mean(), color='chocolate', linestyle='-', linewidth=0.8)
min_ylim, max_ylim = plt.ylim()
plt.text(df_stroke['age'].mean()*1.05, max_ylim*0.95, 'Mean (μ): {:.2f}'.format(df_stroke['age'].mean()))
plt.xlabel("Age (in years)")
plt.title(f"Distribution of Ages")
plt.show()

Based on the histogram above, the mean age of stroke patients is 67-68 years, and most of these patients are between 70 and 80 years old.
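The mean and the modal age band can be recomputed directly from the age column; a minimal sketch, with made-up ages standing in for `df_stroke['age']`:

```python
import pandas as pd

# Hypothetical ages standing in for df_stroke['age'].
ages = pd.Series([45, 58, 67, 70, 74, 78, 79, 80, 81])

mean_age = ages.mean()
# Bucket the ages into decades and find the most populated band.
bands = pd.cut(ages, bins=range(0, 101, 10), right=False)
modal_band = bands.value_counts().idxmax()
print(f"mean: {mean_age:.2f}, modal band: {modal_band}")
```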

In [7]:
fig = make_subplots(
    rows=2, cols=2,subplot_titles=('','<b>Distribution Of Female Ages</b>','<b>Distribution Of Male Ages</b>','Residuals'),
    vertical_spacing=0.09,
    specs=[[{"type": "pie","rowspan": 2}       ,{"type": "histogram"}] ,
           [None                               ,{"type": "histogram"}]            ,                                      
          ]
)

fig.add_trace(
    go.Pie(values=df_stroke.gender.value_counts().values,labels=['<b>Female</b>','<b>Male</b>'],
           hole=0.3,pull=[0,0.08],marker_colors=['pink','lightblue'],textposition='inside'),
    row=1, col=1
)


fig.add_trace(
    go.Histogram(
        x=df_stroke.query('gender=="Female"').age,marker= dict(color='pink'),name='Female Ages'
    ),
    row=1, col=2
)


fig.add_trace(
    go.Histogram(
        x=df_stroke.query('gender=="Male"').age,marker= dict(color='lightblue'),name='Male Ages'
    ),
    row=2, col=2
)


fig.update_layout(
    height=800,
    showlegend=False,
    title_text="<b>Age-Sex Inference</b>",
)

fig.show()

Patients affected by stroke are mostly women. Both women and men aged 75-79 years are the most prone to stroke.
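The pie chart's female/male split can also be read off numerically; a sketch with illustrative labels standing in for `df_stroke['gender']`:

```python
import pandas as pd

# Illustrative labels standing in for df_stroke['gender'].
gender = pd.Series(["Female", "Female", "Female", "Male", "Male"])

# Share of each sex among stroke patients.
share = gender.value_counts(normalize=True)
print(share)
```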

<div style="page-break-after: always;"></div>

In [8]:
temp = pd.pivot_table(df,
                     values = ['stroke'],
                     index = ['hypertension'],
                     columns = ['heart_disease'],
                     aggfunc = {'stroke':['count','mean']})


temp.columns = temp.columns.set_levels(['No','Yes'], level=2)
temp.index = pd.Index(['No','Yes'], name='Hypertension')

temp.style.set_properties(**{'background-color': 'khaki','border-color':'white'},
                         subset=[('stroke','mean','No'),('stroke','mean','Yes')])
Out[8]:
                 stroke
                  count         mean
heart_disease    No     Yes    No        Yes
Hypertension
No             4400     212    0.033864  0.160377
Yes             434      64    0.122120  0.203125
In [9]:
px.imshow(
    temp.loc[:,('stroke','mean')],
    labels = dict(color='Stroke'),
    title = 'Stroke probability',
    color_continuous_scale = px.colors.sequential.Redor,
    **{'width':800, 'height':500})

A patient who has both heart disease and hypertension has a greater chance of having a stroke than a patient with only one of the two conditions.
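The same conditional probabilities can be obtained with a plain `groupby` instead of a pivot table; a sketch on toy records (the mean of the 0/1 stroke column within each group is the empirical stroke rate):

```python
import pandas as pd

# Toy patient records with the three binary columns used above.
df = pd.DataFrame({
    "hypertension":  [0, 0, 0, 0, 1, 1, 1, 1],
    "heart_disease": [0, 0, 1, 1, 0, 0, 1, 1],
    "stroke":        [0, 0, 0, 1, 0, 1, 1, 1],
})

# Mean of the 0/1 stroke column per group = empirical stroke probability.
prob = df.groupby(["hypertension", "heart_disease"])["stroke"].mean()
print(prob)
```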

In [10]:
fig, ax = plt.subplots(2,2, figsize=(15,7))
sns.countplot(df_stroke['ever_married'], ax=ax[0,0], palette='pastel').set(title='Have the patients ever been married?',xlabel=None)
sns.countplot(df_stroke['work_type'], ax=ax[0,1], palette='cubehelix').set(title="The patients' type of work",xlabel=None)
sns.countplot(df_stroke['Residence_type'], ax=ax[1,0], palette='icefire').set(title="The patients' type of residence",xlabel=None)
sns.countplot(df_stroke['smoking_status'], ax=ax[1,1], order=df_stroke['smoking_status'].value_counts().index).set(title="The patients' smoking status",xlabel=None)
Out[10]:
[Text(0.5, 0, ''), Text(0.5, 1.0, "The patients' smoking status")]

Most stroke patients are married, work in the private sector, do not smoke, and live in an urban environment.
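The four count plots can be summarized in one line by taking the most frequent category per column; a sketch on illustrative rows standing in for `df_stroke`:

```python
import pandas as pd

# Illustrative stroke-patient rows with the four categorical columns plotted above.
df_stroke = pd.DataFrame({
    "ever_married": ["Yes", "Yes", "No"],
    "work_type": ["Private", "Private", "Self-employed"],
    "Residence_type": ["Urban", "Rural", "Urban"],
    "smoking_status": ["never smoked", "never smoked", "smokes"],
})

# The most frequent category per column summarizes all four count plots.
modes = df_stroke.mode().iloc[0]
print(modes)
```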

<div style="page-break-after: always;"></div>

In [11]:
fig = make_subplots(shared_yaxes =True,
                   rows=1, cols=2,
                   horizontal_spacing = 0.02,
                   subplot_titles = ('Average Glucose level','Body mass index(BMI)'))

for i in [0,1]:
    if i == 0:
        name = 'No'
        color = 'rgb(217,175,107)'
        group = 'g_No'
    else:
        name = 'Yes'
        color = 'rgb(204,80,62)'
        group = 'g_Yes'
        
    fig.add_trace(
        go.Histogram(
            x = df[df['stroke']==i]['avg_glucose_level'],
            nbinsx = 50,
            legendgroup = group,
            name = name,
            marker = dict(color = color),
            showlegend = False
        ),
        row=1, col=1
        
    )
    
    fig.add_trace(
        go.Histogram(
            x=df[df['stroke']==i]['bmi'],
            nbinsx = 50,
            legendgroup = group,
            name = name,
            marker = dict(color = color)
        ),
        row=1, col=2
    )
    
fig.update_layout(barmode='overlay', bargap=0)
fig.update_xaxes(row=1, col=1, title_text='Glucose level')
fig.update_xaxes(row=1, col=2, title_text='BMI')
fig.update_yaxes(row=1, col=1, title_text='count')
fig.update_layout(legend_title_text='Stroke')

fig.show()

The overlaid histograms show that strokes occur across the full range of average glucose levels and BMI values; no range of blood sugar or body mass is entirely free of stroke risk.
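To make the overlay comparison concrete, the two groups can be compared on a summary statistic such as the median; a sketch on toy numbers (not the real dataset):

```python
import pandas as pd

# Toy glucose readings standing in for avg_glucose_level, split by outcome.
df = pd.DataFrame({
    "avg_glucose_level": [85, 90, 100, 110, 95, 130, 210, 105],
    "stroke": [0, 0, 0, 0, 1, 1, 1, 1],
})

# Compare the two overlaid distributions by their medians.
med = df.groupby("stroke")["avg_glucose_level"].median()
print(med)
```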

<div style="page-break-after: always;"></div>

Preprocessing Data

Handling Missing Values (NaN)

In [12]:
plt.title('Missing Value Graph', fontweight='bold')
ax = sns.heatmap(df.isna().sum().to_frame(),annot=True,fmt='d',cmap='GnBu')
ax.set_xlabel('Amount Missing')
ax.figure.set_size_inches(8.5, 3.5)
plt.show()

na_percentage = (df['bmi'].isnull().values.sum() / df.shape[0]) * 100
print(f"NaN Percentage: {na_percentage:.2f}%")

na_stroke = (df_stroke.isnull().values.sum()/df_stroke.shape[0])*100
print(f"NaN with stroke target Percentage: {na_stroke:.2f}%")

na_nonstroke = (df[df['stroke'] == 0].isnull().values.sum()/df[df['stroke'] == 0].shape[0])*100
print(f"NaN with non stroke target Percentage: {na_nonstroke:.2f}%")
NaN Percentage: 3.93%
NaN with stroke target Percentage: 16.06%
NaN with non stroke target Percentage: 3.31%

Missing values occur only in the bmi column (3.93% of all rows), but they are far more common among stroke patients (16.06%) than among non-stroke patients (3.31%), so dropping them removes a disproportionate share of the scarce stroke cases.
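Because the missing bmi values are concentrated in the stroke class, dropping them discards some of the scarce positive cases. An alternative this notebook does not use is median imputation; a minimal sketch on toy data:

```python
import numpy as np
import pandas as pd

# Toy frame with a missing bmi value, standing in for the real dataset.
df = pd.DataFrame({"bmi": [22.0, np.nan, 30.0, 28.0], "stroke": [0, 1, 0, 1]})

# Median imputation keeps every row, including the stroke cases.
df["bmi"] = df["bmi"].fillna(df["bmi"].median())
print(df)
```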

In [13]:
df = df.dropna()

<div style="page-break-after: always;"></div>

In [14]:
go.Figure(data=[go.Pie(values=df.stroke.value_counts().values,labels=['<b>No Stroke</b>','<b>Stroke</b>'],
                       pull=[0.2,0.1],marker_colors=['darkblue','red'],textposition='inside')])

An imbalance was detected in the target column, so the data will be resampled. Before that, categorical columns with more than two classes will be converted to dummy variables.

Dummification

In [15]:
dummy_ctg = pd.get_dummies(df, columns=['work_type','Residence_type','smoking_status'])
dummy_ctg.info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 4909 entries, 0 to 5109
Data columns (total 20 columns):
 #   Column                          Non-Null Count  Dtype  
---  ------                          --------------  -----  
 0   id                              4909 non-null   int64  
 1   gender                          4909 non-null   object 
 2   age                             4909 non-null   int32  
 3   hypertension                    4909 non-null   int64  
 4   heart_disease                   4909 non-null   int64  
 5   ever_married                    4909 non-null   object 
 6   avg_glucose_level               4909 non-null   float64
 7   bmi                             4909 non-null   float64
 8   stroke                          4909 non-null   int64  
 9   work_type_Govt_job              4909 non-null   uint8  
 10  work_type_Never_worked          4909 non-null   uint8  
 11  work_type_Private               4909 non-null   uint8  
 12  work_type_Self-employed         4909 non-null   uint8  
 13  work_type_children              4909 non-null   uint8  
 14  Residence_type_Rural            4909 non-null   uint8  
 15  Residence_type_Urban            4909 non-null   uint8  
 16  smoking_status_Unknown          4909 non-null   uint8  
 17  smoking_status_formerly smoked  4909 non-null   uint8  
 18  smoking_status_never smoked     4909 non-null   uint8  
 19  smoking_status_smokes           4909 non-null   uint8  
dtypes: float64(2), int32(1), int64(4), object(2), uint8(11)
memory usage: 417.1+ KB

<div style="page-break-after: always;"></div>

Resampling

In [16]:
from sklearn.utils import resample
In [17]:
target_majority = df[df.stroke == 0]
target_minority = df[df.stroke == 1]

target_minority_upsampled = resample(target_minority,
                                     replace=True,     
                                     n_samples= 4700,  
                                     random_state=123) 

df_balance = pd.concat([target_majority, target_minority_upsampled])
In [18]:
df_balance['stroke'].value_counts()
Out[18]:
1    4700
0    4700
Name: stroke, dtype: int64
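One caveat: because upsampling duplicates minority rows, the train/test split should normally happen before resampling, so that copies of the same patient cannot land in both sets. A hedged sketch of that ordering, using sklearn's `train_test_split` together with the `resample` call from above:

```python
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.utils import resample

# Toy imbalanced frame: 8 negatives, 2 positives.
df = pd.DataFrame({"x": range(10), "stroke": [0] * 8 + [1] * 2})

# 1) Split first, stratified on the target.
train, test = train_test_split(df, test_size=0.3,
                               stratify=df["stroke"], random_state=123)

# 2) Upsample the minority class inside the training set only.
majority = train[train.stroke == 0]
minority = train[train.stroke == 1]
minority_up = resample(minority, replace=True,
                       n_samples=len(majority), random_state=123)
train_balanced = pd.concat([majority, minority_up])
print(train_balanced["stroke"].value_counts())
```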

For the dummy dataset

In [19]:
target_majority = dummy_ctg[dummy_ctg.stroke == 0]
target_minority = dummy_ctg[dummy_ctg.stroke == 1]

target_minority_upsampled = resample(target_minority,
                                     replace=True,     
                                     n_samples= 4700,  
                                     random_state=123) 

dummy_ctg = pd.concat([target_majority, target_minority_upsampled])
In [20]:
dummy_ctg['stroke'].value_counts()
Out[20]:
1    4700
0    4700
Name: stroke, dtype: int64
In [21]:
dummy_ctg = dummy_ctg.rename(columns={"gender":"male"})
In [22]:
dummy_ctg = dummy_ctg.replace(['Yes','No','Male','Female'],[1,0,1,0])
In [23]:
dummy_ctg = dummy_ctg[dummy_ctg['male'] != 'Other']

<div style="page-break-after: always;"></div>

Change categorical columns to numerical

In [25]:
dfnumeric = df_balance.replace(['Private','Self-employed','Govt_job','children','Never_worked'],[1,2,3,4,5])
dfnumeric = dfnumeric.replace(['Yes','No','Male','Female'],[1,0,1,0])
dfnumeric = dfnumeric.replace(['Rural','Urban'],[1,2])
dfnumeric = dfnumeric.replace(['never smoked','Unknown','formerly smoked','smokes'],[1,2,3,4])
In [26]:
dfnumeric
Out[26]:
id gender age hypertension heart_disease ever_married work_type Residence_type avg_glucose_level bmi smoking_status stroke
249 30669 1 3 0 0 0 4 1 95.12 18.0 2 0
250 30468 1 58 1 0 1 1 2 87.96 39.2 1 0
251 16523 0 8 0 0 0 1 2 110.89 17.6 2 0
252 56543 0 70 0 0 1 1 1 69.04 35.9 3 0
253 46136 1 14 0 0 0 5 1 161.28 19.1 2 0
... ... ... ... ... ... ... ... ... ... ... ... ...
100 12363 1 64 0 1 1 3 2 74.10 28.8 2 1
132 69551 1 69 1 0 0 1 1 182.99 36.5 1 1
63 19557 0 45 0 0 1 1 1 93.72 30.2 3 1
76 36236 1 80 1 0 1 1 2 240.09 27.0 1 1
53 47167 0 77 1 0 1 2 2 124.13 31.4 1 1

9400 rows × 12 columns

#df_balance.to_csv(r'C:\Users\User\Documents\Python Project\Stroke Prediction\StrokeReady_df.csv', index=False, header=True)
#dummy_ctg.to_csv(r'C:\Users\User\Documents\Python Project\Stroke Prediction\StrokeDummy_df.csv', index=False, header=True)
#dfnumeric.to_csv(r'C:\Users\User\Documents\Python Project\Stroke Prediction\StrokeNumeric_df.csv', index=False, header=True)

The process above produces three dataframes:

  • df_balance : the resampled dataframe
  • dummy_ctg : the resampled dataframe with dummy variables
  • dfnumeric : a dataframe in which every column is numeric

Modeling will be done with several machine learning algorithms: logistic regression, decision tree, XGBoost, and SVM.

Note: modeling is done in separate files to keep this report compact.

<div style="page-break-after: always;"></div>

In [27]:
lr = pd.read_csv("dataset/result/resultLR.csv")
dt = pd.read_csv("dataset/result/resultDT.csv")
xgb = pd.read_csv("dataset/result/resultXGB.csv")
svm = pd.read_csv("dataset/result/resultSVM.csv")
In [28]:
resultml = pd.concat([lr,svm,dt,xgb])
In [29]:
resultml_pivot = resultml.pivot(index='parameter',columns='algorithm', values='score').rename_axis(None)
In [30]:
import plotly
In [31]:
trace1 = go.Bar(x=resultml_pivot['Decision Tree'],y=resultml_pivot.index,name='Decision Tree',orientation='h')
trace2 = go.Bar(x=resultml_pivot['Logistic Regression'],y=resultml_pivot.index,name='Logistic Regression',orientation='h')
trace3 = go.Bar(x=resultml_pivot['SVM'],y=resultml_pivot.index,name='SVM',orientation='h')
trace4 = go.Bar(x=resultml_pivot['XGBoost'],y=resultml_pivot.index,name='XGBoost',orientation='h')

data = [trace1,trace2,trace3,trace4]
plot_result = plotly.offline.iplot({"data": data,
                                    "layout":go.Layout(barmode='group', title={'text': "Model Performance Comparison",
                                                                               'y':0.9,'x':0.5,'xanchor': 'center','yanchor': 'top'},
                                                       width=900,height=500,),

})

The best model to use is XGBoost. Although recall is the main focus in this case, the accuracy across all targets cannot be ignored either.
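Given the stated focus on recall, the choice can also be made programmatically from the score table; a sketch with invented scores shaped like `resultml_pivot` (the parameter names 'accuracy' and 'recall' are assumptions about the result files):

```python
import pandas as pd

# Invented scores shaped like resultml_pivot; real values come from the result CSVs.
resultml_pivot = pd.DataFrame(
    {"Decision Tree": [0.91, 0.88], "Logistic Regression": [0.80, 0.78],
     "SVM": [0.85, 0.84], "XGBoost": [0.95, 0.96]},
    index=["accuracy", "recall"],
)

# Pick the algorithm with the highest recall.
best = resultml_pivot.loc["recall"].idxmax()
print(best)
```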